{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ch-DCRdcNnLY"
      },
      "source": [
        "# L-BFGS\n",
        "\n",
        "L-BFGS is a classical optimization method that uses past gradients and parameters information to iteratively refine a solution to a minimization problem. In this notebook, we illustrate\n",
        "1. how to use L-BFGS as a simple gradient transformation,\n",
        "2. how to wrap L-BFGS in a solver, and how linesearches are incorporated,\n",
        "3. how to debug the solver if needed.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0kUKnsM4nb5m"
      },
      "outputs": [],
      "source": [
        "from typing import NamedTuple\n",
        "\n",
        "import chex\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import jax.random as jrd\n",
        "\n",
        "import optax\n",
        "import optax.tree_utils as otu"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Cja5fd5FTlGs"
      },
      "source": [
        "## L-BFGS as a gradient transformation\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UbkaJmath6Np"
      },
      "source": [
        "### What is L-BFGS?\n",
        "\n",
        "To solve a problem of the form\n",
        "\n",
        "$$\n",
        "\\min_w f(w),\n",
        "$$\n",
        "\n",
        "L-BFGS ([Limited memory Broyden–Fletcher–Goldfarb–Shanno algorithm](https://en.wikipedia.org/wiki/Limited-memory_BFGS)) makes steps of the form\n",
        "\n",
        "$$\n",
        "w_{k+1} = w_k - \\eta_k P_k g_k,\n",
        "$$\n",
        "\n",
        "where, at iteration $k$, $w_k$ are the parameters, $g_k = \\nabla f_k$ are the gradients, $\\eta_k$ is the stepsize, and $P_k$ is a *preconditioning* matrix, that is, a matrix that transforms the gradients to ease the optimization process.\n",
        "\n",
        "L-BFGS builds the preconditioning matrix $P_k$ as an approximation of the Hessian inverse $P_k \\approx \\nabla^2 f(w_k)^{-1}$ using past gradient and parameters information. Briefly, at iteration $k$, the previous preconditioning matrix $P_{k-1}$ is updated such that $P_k$ satisfies the secant condition $P_k(w_k-w_{k-1}) = g_k -g_{k-1}$. The original BFGS algorithm updates $P_k$ using all past information, the limited-memory variant only uses a fixed number of past parameters and gradients to build $P_k$. See [Nocedal and Wright, Numerical Optimization, 1999](https://www.math.uci.edu/~qnie/Publications/NumericalOptimization.pdf) or the [documentation](https://optax.readthedocs.io/en/latest/api/transformations.html#optax.scale_by_lbfgs) for more details on the implementation.\n"
      ]
    },
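    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a hedged illustration of the secant condition, the sketch below applies one dense BFGS update to an inverse-Hessian approximation and checks that the result maps the gradient difference $y = g_k - g_{k-1}$ to the parameter difference $s = w_k - w_{k-1}$. The helper `bfgs_inverse_update`, the matrix `hess` and the two points are hypothetical names and values chosen for this example. This is the textbook dense formula, not what optax runs internally: L-BFGS never forms $P_k$ explicitly and only keeps a limited history of such pairs.\n",
        "\n",
        "```python\n",
        "import jax.numpy as jnp\n",
        "\n",
        "\n",
        "def bfgs_inverse_update(precond, s, y):\n",
        "  # Dense BFGS update of an inverse-Hessian approximation (illustrative).\n",
        "  rho = 1.0 / jnp.dot(y, s)\n",
        "  eye = jnp.eye(s.shape[0])\n",
        "  left = eye - rho * jnp.outer(s, y)\n",
        "  return left @ precond @ left.T + rho * jnp.outer(s, s)\n",
        "\n",
        "\n",
        "# Quadratic objective with a symmetric positive definite Hessian (assumed values).\n",
        "hess = jnp.array([[3.0, 1.0], [1.0, 2.0]])\n",
        "w_prev = jnp.array([1.0, -1.0])\n",
        "w_next = jnp.array([0.5, 0.25])\n",
        "\n",
        "s = w_next - w_prev  # Parameter difference.\n",
        "y = hess @ s  # Gradient difference of the quadratic.\n",
        "\n",
        "new_precond = bfgs_inverse_update(jnp.eye(2), s, y)\n",
        "# The secant condition requires new_precond @ y == s.\n",
        "secant_gap = jnp.max(jnp.abs(new_precond @ y - s))\n",
        "```\n",
        "\n",
        "On this quadratic, `secant_gap` should be zero up to floating point error, whatever initial preconditioner is passed to the update."
      ]
    },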
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2eoJ_CrpmgqK"
      },
      "source": [
        "### Using L-BFGS as a gradient transformation\n",
        "\n",
        "The function {py:func}`optax.scale_by_lbfgs` implements the update of the preconditioning matrix given a running optimizer state $s_k$. Given $(g_k, s_k, w_k)$, this function returns $(P_kg_k, s_{k+1})$. We illustrate its performance below on a simple convex quadratic."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gySpnJ-ch5YT"
      },
      "outputs": [],
      "source": [
        "# Define objective\n",
        "dim = 8\n",
        "w_opt = jnp.ones(dim)\n",
        "mat = jrd.normal(jrd.PRNGKey(0), (dim, dim))\n",
        "mat = mat.dot(mat.T)\n",
        "\n",
        "\n",
        "def fun(w):\n",
        "  return 0.5 * (w - w_opt).dot(mat.dot(w - w_opt))\n",
        "\n",
        "\n",
        "# Define optimizer\n",
        "lr = 1e-1\n",
        "opt = optax.scale_by_lbfgs()\n",
        "\n",
        "# Initialize optimization\n",
        "w = jrd.normal(jrd.PRNGKey(1), (dim,))\n",
        "state = opt.init(w_opt)\n",
        "\n",
        "# Run optimization\n",
        "for i in range(16):\n",
        "  v, g = jax.value_and_grad(fun)(w)\n",
        "  print(f'Iteration: {i}, Value:{v:.2e}')\n",
        "  u, state = opt.update(g, state, w)\n",
        "  w = w - lr * u\n",
        "\n",
        "print(f'Final value: {fun(w):.2e}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f-1P4_79rnow"
      },
      "source": [
        "## L-BFGS as a solver\n",
        "\n",
        "L-BFGS is a sample in numerical optimization to solve medium scale problems. It is often the backend of generic minimization functions in software libraries like [scipy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html#scipy.optimize.minimize). A key ingredient to make it a simple optimization blackbox, is to remove the need of tuning the stepsize, a.k.a. learning rate in machine learning. In a deterministic setting (no additional varying inputs like inputs/labels), such automatic tuning of the stepsize is done by means of linesearches reviewed below."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6KIZE_pFwdcz"
      },
      "source": [
        "### What are linesearches?\n",
        "\n",
        "Given current parameters $w_k$, an update direction $u_k$ (such as the negative preconditioned gradient $u_k = -P_k g_k$ returned by L-BFGS), a linesearch computes a stepsize $\\eta_k$ such that the next iterate $w_{k+1} = w_k + \\eta_k u_k$ satisfies some criterions.\n",
        "\n",
        "#### Sufficient decrease (Armijo-Goldstein criterion)\n",
        "\n",
        "The first criterion that a good stepsize may need to satisfy is to ensure that the next iterate decreases the value of the objective by a a sufficient amount. Mathematically, the criterion is expressed as finding $\\eta_k$ such that\n",
        "\n",
        "$$\n",
        "f(w_k + \\eta_k u_k) \\leq f(w_k) + c_1 \\eta_k \\langle u_k, g_k\\rangle\n",
        "$$\n",
        "\n",
        "where $c_1$ is some constant set to $10^{-4}$ by default. Consider for example the update direction to be $u_k = -g_k$, i.e., moving along the negative gradient direction. In that case the criterion above reduces to $f(w_k - \\eta_k g_k) \\leq f(w_k) - c_1 \\eta_k ||g_k||_2^2$. The criterion amounts then to choosing the stepsize such that it decreases the objective by an amount proportional to the squared gradient norm.\n",
        "\n",
        "As long as the update direction is a *descent direction*, that is, $\\langle u_k, g_k\\rangle \u003c 0$ the above criterion is guaranteed to be satisfied by some sufficiently small stepsize.\n",
        "A simple linesearch technique to ensure a sufficient decrease is then to decrease a candidate stepsize by a constant factor up until the criterion is satisfied. This amounts to the backtracking linesearch implemented in {py:func}`optax.scale_by_backtracking_linesearch` and briefly reviewed below.\n",
        "\n",
        "#### Small curvature (Strong wolfe criterion)\n",
        "\n",
        "The sufficient decrease criterion ensures that the algorithm does not produce a sequence of diverging objective values. However, we may want to not only reduce a current stepsize but also increase it to ensure maximal speed. Ideally, we would like to find the stepsize that minimizes the function along the current update, i.e., $\\eta_k^* = \\arg\\min_\\eta f(w_k + \\eta u_k)$. Such an endeavor can be computationally prohibitive, so we may select a stepsize that ensures some properties that an optimal stepsize would satisfy. In particular, we may search for a stepsize such that the derivative of $h(\\eta) = f(w_k + \\eta u_k)$ is small enough compared to its derivativeœ at $\\eta=0$. Formally, we may want to select the stepsize $\\eta_k$ such that $|h'(eta_k)| \\leq |h'(0)|$, that is,\n",
        "\n",
        "$$\n",
        "|\\langle \\nabla f(w_k + \\eta_k u_k), u_k\\rangle|\n",
        "\\leq |\\langle \\nabla f(w_k), u_k\\rangle|.\n",
        "$$\n",
        "\n",
        "See Chapter 3 of [Nocedal and Wright, Numerical Optimization, 1999](https://www.math.uci.edu/~qnie/Publications/NumericalOptimization.pdf) for some illustrations of this criterion. A linesearch method that can ensure both criterions require some form of bisection method implemented in optax with the {py:func}`optax.scale_by_zoom_linesearch` method. Several other linesearch techniques exist, see e.g. https://github.com/JuliaNLSolvers/LineSearches.jl. It is generally recommended to combine L-BFGS with a line-search ensuring both sufficient decrease and small curvature, which the {py:func}`optax.scale_by_zoom_linesearch` ensures.\n",
        "\n"
      ]
    },
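    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The two criteria above can be sketched in a few lines of plain JAX. This is a minimal illustration under stated assumptions, not the optax implementation: `wolfe_checks`, `backtracking_stepsize` and `quadratic` are hypothetical helpers, the halving factor and the initial stepsize of one are arbitrary choices, and `c2` plays the role of the relative tolerance on the curvature (its value of 0.9 here is an assumption of this sketch).\n",
        "\n",
        "```python\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "\n",
        "def wolfe_checks(fun, w, u, stepsize, c1=1e-4, c2=0.9):\n",
        "  # Returns (sufficient_decrease, small_curvature) for a candidate stepsize.\n",
        "  value, grad = jax.value_and_grad(fun)(w)\n",
        "  new_value, new_grad = jax.value_and_grad(fun)(w + stepsize * u)\n",
        "  slope = jnp.vdot(u, grad)  # Derivative along u at stepsize 0.\n",
        "  new_slope = jnp.vdot(u, new_grad)  # Derivative along u at the candidate.\n",
        "  decrease_ok = new_value <= value + c1 * stepsize * slope\n",
        "  curvature_ok = jnp.abs(new_slope) <= c2 * jnp.abs(slope)\n",
        "  return bool(decrease_ok), bool(curvature_ok)\n",
        "\n",
        "\n",
        "def backtracking_stepsize(fun, w, u, factor=0.5, max_steps=30):\n",
        "  # Shrinks the stepsize until the sufficient decrease criterion holds.\n",
        "  stepsize = 1.0\n",
        "  for _ in range(max_steps):\n",
        "    if wolfe_checks(fun, w, u, stepsize)[0]:\n",
        "      break\n",
        "    stepsize *= factor\n",
        "  return stepsize\n",
        "\n",
        "\n",
        "def quadratic(w):\n",
        "  # Ill-conditioned quadratic: a unit step along the gradient is too large.\n",
        "  return 0.5 * jnp.sum(jnp.array([1.0, 100.0]) * w**2)\n",
        "\n",
        "\n",
        "w0 = jnp.array([1.0, 1.0])\n",
        "u0 = -jax.grad(quadratic)(w0)  # Negative gradient, a descent direction.\n",
        "eta = backtracking_stepsize(quadratic, w0, u0)\n",
        "decrease_ok, curvature_ok = wolfe_checks(quadratic, w0, u0, eta)\n",
        "```\n",
        "\n",
        "Backtracking only ever shrinks the stepsize, so it targets the sufficient decrease criterion alone; `curvature_ok` is reported for information and is not enforced by this sketch."
      ]
    },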
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PmaDMT8O2dub"
      },
      "source": [
        "### Linesearches in practice\n",
        "\n",
        "To find a stepsize satisfying the above criterions, a linesearch needs to access the value and potentially the gradient of the function. So linesearches in optax are implemented as {py:func}`optax.GradientTransformationExtraArgs`, which take the current value, gradient of the objective as well as the function itself. We illustrate this below with {py:func}`optax.scale_by_backtracking_linesearch`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Fz3RcdDA3714"
      },
      "outputs": [],
      "source": [
        "# Objective\n",
        "def fun(w):\n",
        "  return jnp.sum(jnp.abs(w))\n",
        "\n",
        "\n",
        "# Linesearch, comment/uncomment the desired one\n",
        "linesearch = optax.scale_by_backtracking_linesearch(max_backtracking_steps=15)\n",
        "# linesearch = optax.scale_by_zoom_linesearch(max_linesearch_steps=15)\n",
        "\n",
        "# Optimizer\n",
        "opt = optax.chain(\n",
        "    optax.sgd(learning_rate=1.0),\n",
        "    # Compare with or without linesearch by commenting this line\n",
        "    linesearch,\n",
        ")\n",
        "\n",
        "# Initialize\n",
        "w = jrd.normal(jrd.PRNGKey(0), (8,))\n",
        "state = opt.init(w)\n",
        "\n",
        "# Run optimization\n",
        "for i in range(16):\n",
        "  v, g = jax.value_and_grad(fun)(w)\n",
        "  print(f'Iteration: {i}, Value:{v:.2e}')\n",
        "  u, state = opt.update(g, state, w, value=v, grad=g, value_fn=fun)\n",
        "  w = w + u\n",
        "\n",
        "print(f'Final value: {fun(w):.2e}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CLShtM0s3TAP"
      },
      "source": [
        "To validate the stepsize the linesearch calls the function several times. If a stepsize is accepted, we have then a priori access to the value of the function, and, potentially its gradient. The implementation of the linesearches in optax store the value and the gradient computed by the linesearch to avoid recomputing them at the next step. In practice, the code above can be modified as follows.\n",
        "\n",
        "*Note:*\n",
        "The backtracking linesearch only evaluates the function and does not compute the gradient natively. To make the backtracking linesearch compute and store the gradient at the stepsize taken, we add the flag `store_grad=True`, see below.\n",
        "The zoom linesearch always compute both function and gradient so there is no need to specify an additional flag."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RFH5Llz06iwX"
      },
      "outputs": [],
      "source": [
        "# Objective\n",
        "def fun(w):\n",
        "  return jnp.sum(jnp.abs(w))\n",
        "\n",
        "\n",
        "# Linesearch\n",
        "linesearch = optax.scale_by_backtracking_linesearch(\n",
        "    max_backtracking_steps=15, store_grad=True\n",
        ")\n",
        "# linesearch = optax.scale_by_zoom_linesearch(max_linesearch_steps=15)\n",
        "\n",
        "# Optimizer\n",
        "opt = optax.chain(optax.sgd(learning_rate=1.0), linesearch)\n",
        "\n",
        "# Initialize\n",
        "w = jrd.normal(jrd.PRNGKey(0), (8,))\n",
        "state = opt.init(w)\n",
        "\n",
        "# Run optimization\n",
        "for _ in range(16):\n",
        "  # Replace `v, g = jax.value_and_grad(fun)(w)` by\n",
        "  v, g = optax.value_and_grad_from_state(fun)(w, state=state)\n",
        "  u, state = opt.update(g, state, w, value=v, grad=g, value_fn=fun)\n",
        "  w = w + u\n",
        "\n",
        "print(f'Final value: {fun(w):.2e}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Fqylj3ANOiYa"
      },
      "source": [
        "### L-BFGS solver\n",
        "\n",
        "Optax combines then the gradient transformation of L-BFGS and a linesearch in `optax.lbfgs()`.\n",
        "\n",
        "We present below a wrapper that combines both into a solver which tries to find the minimizer of a function given\n",
        "1. some initial parameters `init_params`,\n",
        "2. the function to optimize `fun`,\n",
        "3. the instance of the L-BFGS solver considered `opt`,\n",
        "4. a maximal number of iteration `max_iter`,\n",
        "5. a tolerance `tol` on the optimization error measured here as the gradient norm."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3BM2rlAGUx7K"
      },
      "outputs": [],
      "source": [
        "def run_opt(init_params, fun, opt, max_iter, tol):\n",
        "  value_and_grad_fun = optax.value_and_grad_from_state(fun)\n",
        "\n",
        "  def step(carry):\n",
        "    params, state = carry\n",
        "    value, grad = value_and_grad_fun(params, state=state)\n",
        "    updates, state = opt.update(\n",
        "        grad, state, params, value=value, grad=grad, value_fn=fun\n",
        "    )\n",
        "    params = optax.apply_updates(params, updates)\n",
        "    return params, state\n",
        "\n",
        "  def continuing_criterion(carry):\n",
        "    _, state = carry\n",
        "    iter_num = otu.tree_get(state, 'count')\n",
        "    grad = otu.tree_get(state, 'grad')\n",
        "    err = otu.tree_l2_norm(grad)\n",
        "    return (iter_num == 0) | ((iter_num \u003c max_iter) \u0026 (err \u003e= tol))\n",
        "\n",
        "  init_carry = (init_params, opt.init(init_params))\n",
        "  final_params, final_state = jax.lax.while_loop(\n",
        "      continuing_criterion, step, init_carry\n",
        "  )\n",
        "  return final_params, final_state"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ohk5KQqfZgWt"
      },
      "source": [
        "We can test the solver on the [Rosenbrock function](https://www.sfu.ca/~ssurjano/rosen.html)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bVrZLiSVZfue"
      },
      "outputs": [],
      "source": [
        "def fun(w):\n",
        "  return jnp.sum(100.0 * (w[1:] - w[:-1] ** 2) ** 2 + (1.0 - w[:-1]) ** 2)\n",
        "\n",
        "opt = optax.lbfgs()\n",
        "init_params = jnp.zeros((8,))\n",
        "print(\n",
        "    f'Initial value: {fun(init_params):.2e} '\n",
        "    f'Initial gradient norm: {otu.tree_l2_norm(jax.grad(fun)(init_params)):.2e}'\n",
        ")\n",
        "final_params, _ = run_opt(init_params, fun, opt, max_iter=100, tol=1e-3)\n",
        "print(\n",
        "    f'Final value: {fun(final_params):.2e}, '\n",
        "    f'Final gradient norm: {otu.tree_l2_norm(jax.grad(fun)(final_params)):.2e}'\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KinZFIXxbBxy"
      },
      "source": [
        "We may add additional information by simply chaining `optax.lbfgs` with an identity transform that just prints relevant information as follows."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YB4NqbThbdo7"
      },
      "outputs": [],
      "source": [
        "class InfoState(NamedTuple):\n",
        "  iter_num: chex.Numeric\n",
        "\n",
        "\n",
        "def print_info():\n",
        "  def init_fn(params):\n",
        "    del params\n",
        "    return InfoState(iter_num=0)\n",
        "\n",
        "  def update_fn(updates, state, params, *, value, grad, **extra_args):\n",
        "    del params, extra_args\n",
        "\n",
        "    jax.debug.print(\n",
        "        'Iteration: {i}, Value: {v}, Gradient norm: {e}',\n",
        "        i=state.iter_num,\n",
        "        v=value,\n",
        "        e=otu.tree_l2_norm(grad),\n",
        "    )\n",
        "    return updates, InfoState(iter_num=state.iter_num + 1)\n",
        "\n",
        "  return optax.GradientTransformationExtraArgs(init_fn, update_fn)\n",
        "\n",
        "\n",
        "def fun(w):\n",
        "  return jnp.sum(100.0 * (w[1:] - w[:-1] ** 2) ** 2 + (1.0 - w[:-1]) ** 2)\n",
        "\n",
        "\n",
        "opt = optax.chain(print_info(), optax.lbfgs())\n",
        "init_params = jnp.zeros((8,))\n",
        "print(\n",
        "    f'Initial value: {fun(init_params):.2e} '\n",
        "    f'Initial gradient norm: {otu.tree_l2_norm(jax.grad(fun)(init_params)):.2e}'\n",
        ")\n",
        "final_params, _ = run_opt(init_params, fun, opt, max_iter=100, tol=1e-3)\n",
        "print(\n",
        "    f'Final value: {fun(final_params):.2e}, '\n",
        "    f'Final gradient norm: {otu.tree_l2_norm(jax.grad(fun)(final_params)):.2e}'\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KZIu7UDveO6D"
      },
      "source": [
        "## Debugging\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LV8CslWpoDDq"
      },
      "source": [
        "### Accessing debug information\n",
        "\n",
        "In some cases, L-BFGS with a linesearch as a solver will fail. Most of the times, the culprit goes down to the linesearch. To debug the solver in such cases, we provide a `verbose` option to the `optax.scale_by_zoom_linesearch`. We show below how to proceed.\n",
        "\n",
        "To demonstrate such bug, we try to minimize the [Zakharov function](https://www.sfu.ca/~ssurjano/zakharov.html) and set the `scale_init_precond` option to `False` (by choosing the default option `scale_init_precond=True`, the algorithm would actually run fine, we just want to showcase the possibility to use debugging in the linesearch here). You'll observe that the final value is is the same as the initial value which points out that the solver failed."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NbyxeORif9wf"
      },
      "outputs": [],
      "source": [
        "def fun(w):\n",
        "  ii = jnp.arange(1, len(w) + 1, step=1, dtype=w.dtype)\n",
        "  sum1 = (w**2).sum()\n",
        "  sum2 = (0.5 * ii * w).sum()\n",
        "  return sum1 + sum2**2 + sum2**4\n",
        "\n",
        "opt = optax.lbfgs(scale_init_precond=False)\n",
        "\n",
        "init_params = jnp.array([600.0, 700.0, 200.0, 100.0, 90.0, 1e4])\n",
        "print(\n",
        "    f'Initial value: {fun(init_params)} '\n",
        "    f'Initial gradient norm: {otu.tree_l2_norm(jax.grad(fun)(init_params))}'\n",
        ")\n",
        "final_params, _ = run_opt(init_params, fun, opt, max_iter=50, tol=1e-3)\n",
        "print(\n",
        "    f'Final value: {fun(final_params)}, '\n",
        "    f'Final gradient norm: {otu.tree_l2_norm(jax.grad(fun)(final_params))}'\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uwcbY5UXohZB"
      },
      "source": [
        "The default implementation of the linesearch in the code is\n",
        "```\n",
        "scale_by_zoom_linesearch(max_linesearch_steps=20, initial_guess_strategy='one')\n",
        "```\n",
        "To debug we can set the verbose option of the linesearch to `True`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "j8FUTkc8o3l2"
      },
      "outputs": [],
      "source": [
        "opt = optax.chain(print_info(), optax.lbfgs(scale_init_precond=False,\n",
        "  linesearch=optax.scale_by_zoom_linesearch(\n",
        "      max_linesearch_steps=20, verbose=True, initial_guess_strategy='one'\n",
        "  )\n",
        "))\n",
        "\n",
        "init_params = jnp.array([600.0, 700.0, 200.0, 100.0, 90.0, 1e4])\n",
        "print(\n",
        "    f'Initial value: {fun(init_params):.2e} '\n",
        "    f'Initial gradient norm: {otu.tree_l2_norm(jax.grad(fun)(init_params)):.2e}'\n",
        ")\n",
        "final_params, _ = run_opt(init_params, fun, opt, max_iter=100, tol=1e-3)\n",
        "print(\n",
        "    f'Final value: {fun(final_params):.2e}, '\n",
        "    f'Final gradient norm: {otu.tree_l2_norm(jax.grad(fun)(final_params)):.2e}'\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nCgpjzCbo7p9"
      },
      "source": [
        "As expected, the linesearch failed at the very first step taking a stepsize that did not ensure a sufficient decrease. Multiple information is displayed. For example, the slope (derivative along the update direction) at the first step is extremely large which explains the difficulties to find an appropriate stepsize. As pointed out in the log above, the first thing to try is to use a larger number of linesearch steps."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nA9WVXykpaKf"
      },
      "outputs": [],
      "source": [
        "opt = optax.chain(print_info(), optax.lbfgs(scale_init_precond=False,\n",
        "  linesearch=optax.scale_by_zoom_linesearch(\n",
        "      max_linesearch_steps=50, verbose=True, initial_guess_strategy='one'\n",
        "  )\n",
        "))\n",
        "\n",
        "init_params = jnp.array([600.0, 700.0, 200.0, 100.0, 90.0, 1e4])\n",
        "print(\n",
        "    f'Initial value: {fun(init_params):.2e} '\n",
        "    f'Initial gradient norm: {otu.tree_l2_norm(jax.grad(fun)(init_params)):.2e}'\n",
        ")\n",
        "final_params, _ = run_opt(init_params, fun, opt, max_iter=100, tol=1e-3)\n",
        "print(\n",
        "    f'Final value: {fun(final_params):.2e}, '\n",
        "    f'Final gradient norm: {otu.tree_l2_norm(jax.grad(fun)(final_params)):.2e}'\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "na-7s1Q2o1Rc"
      },
      "source": [
        "By simply taking a maximum of 50 steps of the linesearch instead of 20, we ensured that the first stepsize taken provided a sufficient decrease and the solver worked well.\n",
        "Additional debugging information can be found in the source code accessible from the docs of {py:func}`optax.scale_by_zoom_linesearch`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "74ZbgzcKoJ0J"
      },
      "source": [
        "### Tips\n",
        "\n",
        "- **LBFGS**\n",
        "  - Selecting a higher `memory_size` in lbfgs may improve performance at a memory and computational cost. No real gains may be perceived after some value.\n",
        "  - `scale_init_precond=True` is standard. It captures a similar scale as other well-known optimization methods like Barzilai Borwein.\n",
        "\n",
        "- **Zoom linesearch**\n",
        "  - Remember there are two conditions to be met (sufficient decrease and small curvature). If the algorithm takes too many linesearch steps, you may try\n",
        "  setting `curv_rtol = jnp.inf`, effectively ignoring the small curvature condition. The resulting algorithm will essentially perform a backtracking linesearch where a valid stepsize is searched by minmizing a quadratic or cubic approximation of the objective (so that would be a potentially faster algorithm than the current implementation of `scale_by_backtracking_linesearch`).\n",
        "  - As pointed above, if the solver gets stuck, try using a larger number of linesearch steps and print debugging information.\n",
        "\n",
        "You may run the solver in double precision by setting `jax.config.update(\"jax_enable_x64\", True)`. If you use double precision, consider augmenting the number of linesearch steps to reach the machine precision (like using `max_linesearch_steps=55`).\n"
      ]
    },
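    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The suggestion to use more linesearch steps in double precision can be related to machine precision with a back-of-the-envelope computation. The sketch below assumes, for illustration only, that each linesearch step halves the candidate stepsize starting from one (the zoom linesearch actually relies on interpolation, so this is a rough guide at best) and counts the halvings needed to reach the machine epsilon of each floating point type. `halvings_to_reach_eps` is a hypothetical helper, not part of optax.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def halvings_to_reach_eps(dtype):\n",
        "  # Number of times 1.0 must be halved to fall to the machine epsilon.\n",
        "  eps = float(np.finfo(dtype).eps)\n",
        "  return math.ceil(-math.log2(eps))\n",
        "\n",
        "\n",
        "steps_float32 = halvings_to_reach_eps(np.float32)\n",
        "steps_float64 = halvings_to_reach_eps(np.float64)\n",
        "```\n",
        "\n",
        "Under this assumption, double precision needs more than twice as many steps as single precision, which is consistent with increasing `max_linesearch_steps` after enabling `jax_enable_x64`."
      ]
    },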
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "T-oGa3P2sCbH"
      },
      "source": [
        "## Contributing and benchmarking\n",
        "\n",
        "Numerous other linesearch could be implemented, as well as other solvers for medium scale problems without stochasticity. Contributions are welcome.\n",
        "\n",
        "If you want to contribute a new solver for medium scale problems like LBFGS, benchmarks would be highly appreciated. We provide below an example of benchmark (which could also be used if you want to test some hyperparameters of the algorithm). We take here the classical Rosenbroke function, but it could be better to expand such benchmarks to e.g. the set of test functions given by [Andrei, 2008](https://camo.ici.ro/journal/vol10/v10a10.pdf)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MagDCuGjsB5x"
      },
      "outputs": [],
      "source": [
        "import time\n",
        "num_fun_calls = 0\n",
        "\n",
        "def register_call():\n",
        "  global num_fun_calls\n",
        "  num_fun_calls += 1\n",
        "\n",
        "def test_hparams(lbfgs_hparams, linesearch_hparams, dimension=512):\n",
        "  global num_fun_calls\n",
        "  num_fun_calls = 0\n",
        "\n",
        "  def fun(x):\n",
        "    jax.debug.callback(register_call)\n",
        "    return jnp.sum((x[1:] - x[:-1] ** 2) ** 2 + (1.0 - x[:-1]) ** 2)\n",
        "\n",
        "  opt = optax.chain(optax.lbfgs(**lbfgs_hparams,\n",
        "    linesearch=optax.scale_by_zoom_linesearch(**linesearch_hparams)\n",
        "    )\n",
        "  )\n",
        "\n",
        "  init_params = jnp.arange(dimension, dtype=jnp.float32)\n",
        "\n",
        "  tic = time.time()\n",
        "  final_params, _ = run_opt(\n",
        "      init_params, fun, opt, max_iter=500, tol=5*1e-5\n",
        "    )\n",
        "  final_params = jax.block_until_ready(final_params)\n",
        "  time_run = time.time() - tic\n",
        "\n",
        "  final_value = fun(final_params)\n",
        "  final_grad_norm = otu.tree_l2_norm(jax.grad(fun)(final_params))\n",
        "  return final_value, final_grad_norm, num_fun_calls, time_run\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7CXMxWsztGf5"
      },
      "outputs": [],
      "source": [
        "import copy\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "default_lbfgs_hparams = {'memory_size': 15, 'scale_init_precond': True}\n",
        "default_linesearch_hparams = {\n",
        "    'max_linesearch_steps': 15,\n",
        "    'initial_guess_strategy': 'one'\n",
        "}\n",
        "\n",
        "memory_sizes = [int(2**i) for i in range(7)]\n",
        "times = []\n",
        "calls = []\n",
        "values = []\n",
        "grad_norms = []\n",
        "for m in memory_sizes:\n",
        "  lbfgs_hparams = copy.deepcopy(default_lbfgs_hparams)\n",
        "  lbfgs_hparams['memory_size'] = m\n",
        "  v, g, n, t = test_hparams(lbfgs_hparams, default_linesearch_hparams, dimension=1024)\n",
        "  values.append(v)\n",
        "  grad_norms.append(g)\n",
        "  calls.append(n)\n",
        "  times.append(t)\n",
        "\n",
        "fig, axs = plt.subplots(1, 4, figsize=(16, 4))\n",
        "axs[0].plot(memory_sizes, values)\n",
        "axs[0].set_ylabel('Final values')\n",
        "axs[0].set_yscale('log')\n",
        "axs[1].plot(memory_sizes, grad_norms)\n",
        "axs[1].set_ylabel('Final gradient norms')\n",
        "axs[1].set_yscale('log')\n",
        "axs[2].plot(memory_sizes, calls)\n",
        "axs[2].set_ylabel('Number of function calls')\n",
        "axs[3].plot(memory_sizes, times)\n",
        "axs[3].set_ylabel('Run times')\n",
        "for i in range(4):\n",
        "  axs[i].set_xlabel('Memory size')\n",
        "plt.tight_layout()"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "last_runtime": {
        "build_target": "//learning/grp/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
